{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5e76095a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device= cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#快速演示\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print('device=', device)\n",
    "\n",
    "from transformers import BertModel\n",
    "\n",
    "#加载预训练模型\n",
    "pretrained = BertModel.from_pretrained('bert-base-chinese')\n",
    "#需要移动到cuda上\n",
    "pretrained.to(device)\n",
    "\n",
    "#不训练,不需要计算梯度\n",
    "for param in pretrained.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "\n",
    "#定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc = torch.nn.Linear(768, 2)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        with torch.no_grad():\n",
    "            out = pretrained(input_ids=input_ids,\n",
    "                             attention_mask=attention_mask,\n",
    "                             token_type_ids=token_type_ids)\n",
    "\n",
    "        out = self.fc(out.last_hidden_state[:, 0])\n",
    "        out = out.softmax(dim=1)\n",
    "        return out\n",
    "\n",
    "\n",
    "model = Model()\n",
    "#同样要移动到cuda\n",
    "model.to(device)\n",
    "\n",
    "#虚拟一批数据,需要把所有的数据都移动到cuda上\n",
    "input_ids = torch.ones(16, 100).long().to(device)\n",
    "attention_mask = torch.ones(16, 100).long().to(device)\n",
    "token_type_ids = torch.ones(16, 100).long().to(device)\n",
    "labels = torch.ones(16).long().to(device)\n",
    "\n",
    "#试算\n",
    "model(input_ids=input_ids,\n",
    "      attention_mask=attention_mask,\n",
    "      token_type_ids=token_type_ids).shape\n",
    "\n",
    "#后面的计算和中文分类完全一样，只是放在了cuda上计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3362a434",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(9600,\n",
       " ('选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',\n",
       "  1))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, load_from_disk\n",
    "\n",
    "\n",
    "#定义数据集\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "\n",
    "    def __init__(self, split):\n",
    "        #self.dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)\n",
    "        self.dataset = load_from_disk('./data/ChnSentiCorp')[split]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        text = self.dataset[i]['text']\n",
    "        label = self.dataset[i]['label']\n",
    "\n",
    "        return text, label\n",
    "\n",
    "\n",
    "dataset = Dataset('train')\n",
    "\n",
    "len(dataset), dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e59695a4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "600\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 500]),\n",
       " torch.Size([16, 500]),\n",
       " torch.Size([16, 500]),\n",
       " tensor([1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1], device='cuda:0'))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertTokenizer\n",
    "\n",
    "#加载字典和分词工具\n",
    "token = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "\n",
    "\n",
    "def collate_fn(data):\n",
    "    sents = [i[0] for i in data]\n",
    "    labels = [i[1] for i in data]\n",
    "\n",
    "    #编码\n",
    "    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,\n",
    "                                   truncation=True,\n",
    "                                   padding='max_length',\n",
    "                                   max_length=500,\n",
    "                                   return_tensors='pt',\n",
    "                                   return_length=True)\n",
    "\n",
    "    #input_ids:编码之后的数字\n",
    "    #attention_mask:是补零的位置是0,其他位置是1\n",
    "    input_ids = data['input_ids'].to(device)\n",
    "    attention_mask = data['attention_mask'].to(device)\n",
    "    token_type_ids = data['token_type_ids'].to(device)\n",
    "    labels = torch.LongTensor(labels).to(device)\n",
    "\n",
    "    #print(data['length'], data['length'].max())\n",
    "\n",
    "    return input_ids, attention_mask, token_type_ids, labels\n",
    "\n",
    "\n",
    "#数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=16,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    break\n",
    "\n",
    "print(len(loader))\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1bd44a7c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.7148881554603577 0.6875\n",
      "5 0.643970787525177 0.8125\n",
      "10 0.7154413461685181 0.375\n",
      "15 0.639808177947998 0.6875\n",
      "20 0.6095095872879028 0.8125\n",
      "25 0.6154075860977173 0.5\n",
      "30 0.5432500839233398 0.9375\n",
      "35 0.5729296207427979 0.8125\n",
      "40 0.659771203994751 0.5625\n",
      "45 0.5986269116401672 0.6875\n",
      "50 0.5581840872764587 0.8125\n",
      "55 0.5056470036506653 0.875\n",
      "60 0.49638205766677856 0.9375\n",
      "65 0.625960111618042 0.75\n",
      "70 0.4277273416519165 1.0\n",
      "75 0.5182679295539856 0.8125\n",
      "80 0.5487587451934814 0.8125\n",
      "85 0.48724329471588135 0.9375\n",
      "90 0.48124760389328003 0.8125\n",
      "95 0.5316541194915771 0.8125\n",
      "100 0.5484813451766968 0.75\n"
     ]
    }
   ],
   "source": [
    "from transformers import AdamW\n",
    "\n",
    "#训练\n",
    "optimizer = AdamW(model.parameters(), lr=5e-4)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    out = model(input_ids=input_ids,\n",
    "                attention_mask=attention_mask,\n",
    "                token_type_ids=token_type_ids)\n",
    "\n",
    "    loss = criterion(out, labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    if i % 5 == 0:\n",
    "        out = out.argmax(dim=1)\n",
    "        accuracy = (out == labels).sum().item() / len(labels)\n",
    "\n",
    "        print(i, loss.item(), accuracy)\n",
    "\n",
    "    if i == 100:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "275dd1b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.85625\n"
     ]
    }
   ],
   "source": [
    "#测试\n",
    "def test():\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    for i, (input_ids, attention_mask, token_type_ids,\n",
    "            labels) in enumerate(loader_test):\n",
    "\n",
    "        if i == 5:\n",
    "            break\n",
    "\n",
    "        print(i)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            out = model(input_ids=input_ids,\n",
    "                        attention_mask=attention_mask,\n",
    "                        token_type_ids=token_type_ids)\n",
    "\n",
    "        out = out.argmax(dim=1)\n",
    "        correct += (out == labels).sum().item()\n",
    "        total += len(labels)\n",
    "\n",
    "    print(correct / total)\n",
    "\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
